{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Your Own Visualizations!\n",
    "Instructions:\n",
    "1. Install tensor2tensor and train up a Transformer model following the instruction in the repository https://github.com/tensorflow/tensor2tensor.\n",
    "2. Update cell 3 to point to your checkpoint, it is currently set up to read from the default checkpoint location that would be created from following the instructions above.\n",
    "3. If you used custom hyper parameters then update cell 4.\n",
    "4. Run the notebook!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import json\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "from tensor2tensor.tpu import tpu_trainer_lib\n",
    "from tensor2tensor.utils import t2t_model\n",
    "from tensor2tensor.utils import decoding\n",
    "from tensor2tensor.utils import devices\n",
    "from tensor2tensor.visualization import attention\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "require.config({\n",
       "  paths: {\n",
       "      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min'\n",
       "  }\n",
       "});"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%javascript\n",
    "require.config({\n",
    "  paths: {\n",
    "      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min'\n",
    "  }\n",
    "});"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# PUT THE MODEL YOU WANT TO LOAD HERE!\n",
    "\n",
    "PROBLEM = 'translate_ende_wmt32k'\n",
    "MODEL = 'transformer'\n",
    "HPARAMS = 'transformer_base_single_gpu'\n",
    "\n",
    "DATA_DIR=os.path.expanduser('~/t2t_data')\n",
    "TRAIN_DIR=os.path.expanduser('~/t2t_train/%s/%s-%s' % (PROBLEM, MODEL, HPARAMS))\n",
    "print(TRAIN_DIR)\n",
    "\n",
    "FLAGS = tf.flags.FLAGS\n",
    "FLAGS.problems = PROBLEM\n",
    "FLAGS.hparams_set = HPARAMS\n",
    "FLAGS.data_dir = DATA_DIR\n",
    "FLAGS.model = MODEL\n",
    "\n",
    "FLAGS.schedule = 'train_and_evaluate'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:datashard_devices: ['gpu:0']\n",
      "INFO:tensorflow:caching_devices: None\n",
      "INFO:tensorflow:batching_scheme = {'min_length': 0, 'window_size': 720, 'shuffle_queue_size': 270, 'boundaries': [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 28, 30, 33, 36, 39, 42, 46, 50, 55, 60, 66, 72, 79, 86, 94, 103, 113, 124, 136, 149, 163, 179, 196, 215, 236], 'max_length': 1000000000, 'batch_sizes': [240, 180, 180, 180, 144, 144, 144, 120, 120, 120, 90, 90, 90, 90, 80, 72, 72, 60, 60, 48, 48, 48, 40, 40, 36, 30, 30, 24, 24, 20, 20, 18, 18, 16, 15, 12, 12, 10, 10, 9, 8, 8]}\n",
      "INFO:tensorflow:Updated batching_scheme = {'min_length': 0, 'window_size': 720, 'shuffle_queue_size': 270, 'boundaries': [], 'max_length': 1000000000, 'batch_sizes': [1]}\n",
      "INFO:tensorflow:Reading data files from /usr/local/google/home/llion/t2t_data/translate_ende_wmt32k-dev*\n"
     ]
    }
   ],
   "source": [
    "hparams = tpu_trainer_lib.create_hparams(FLAGS.hparams_set, data_dir=FLAGS.data_dir, problem_name=PROBLEM)\n",
    "hparams.use_fixed_batch_size = True\n",
    "hparams.batch_size = 1\n",
    "\n",
    "# SET EXTRA HYPER PARAMS HERE!\n",
    "#hparams.null_slot = True\n",
    "\n",
    "mode = tf.estimator.ModeKeys.EVAL\n",
    "\n",
    "problem = hparams.problem_instances[0]\n",
    "inputs, target = problem.input_fn(\n",
    "      mode=mode,\n",
    "      hparams=hparams,\n",
    "      data_dir=DATA_DIR)\n",
    "\n",
    "features = inputs\n",
    "features['targets'] = target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def encode(string):\n",
    "    subtokenizer = hparams.problems[0].vocabulary['inputs']\n",
    "    return [subtokenizer.encode(string) + [1] + [0]]\n",
    "\n",
    "def decode(ids):\n",
    "    return hparams.problems[0].vocabulary['targets'].decode(np.squeeze(ids))\n",
    "\n",
    "def to_tokens(ids):\n",
    "    ids = np.squeeze(ids)\n",
    "    subtokenizer = hparams.problems[0].vocabulary['targets']\n",
    "    tokens = []\n",
    "    for _id in ids:\n",
    "        if _id == 0:\n",
    "            tokens.append('<PAD>')\n",
    "        elif _id == 1:\n",
    "            tokens.append('<EOS>')\n",
    "        else:\n",
    "            tokens.append(subtokenizer._subtoken_id_to_subtoken_string(_id))\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:datashard_devices: ['gpu:0']\n",
      "INFO:tensorflow:caching_devices: None\n",
      "INFO:tensorflow:Doing model_fn_body took 1.881 sec.\n",
      "INFO:tensorflow:This model_fn took 2.023 sec.\n"
     ]
    }
   ],
   "source": [
    "decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)\n",
    "model_fn = t2t_model.T2TModel.make_estimator_model_fn(\n",
    "    MODEL,\n",
    "    hparams,\n",
    "    decode_hparams=decode_hparams)\n",
    "est_spec = model_fn(features, target, mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:datashard_devices: ['gpu:0']\n",
      "INFO:tensorflow:caching_devices: None\n",
      "INFO:tensorflow:Beam Decoding with beam size 4\n",
      "INFO:tensorflow:Doing model_fn_body took 1.393 sec.\n",
      "INFO:tensorflow:This model_fn took 1.504 sec.\n"
     ]
    }
   ],
   "source": [
    "with tf.variable_scope(tf.get_variable_scope(), reuse=True):\n",
    "    beam_out = model_fn(features, target, tf.contrib.learn.ModeKeys.INFER)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from /usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu/model.ckpt-1\n",
      "INFO:tensorflow:Starting standard services.\n",
      "INFO:tensorflow:Starting queue runners.\n",
      "INFO:tensorflow:Saving checkpoint to path /usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu/model.ckpt\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sv = tf.train.Supervisor(\n",
    "    logdir=TRAIN_DIR,\n",
    "    global_step=tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step'))\n",
    "sess = sv.PrepareSession(config=tf.ConfigProto(allow_soft_placement=True))\n",
    "sv.StartQueueRunners(\n",
    "    sess,\n",
    "    tf.get_default_graph().get_collection(tf.GraphKeys.QUEUE_RUNNERS))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Get the attention tensors from the graph.\n",
    "# This need to be done using the training graph since the inference uses a tf.while_loop\n",
    "# and you cant fetch tensors from inside a while_loop.\n",
    "\n",
    "enc_atts = []\n",
    "dec_atts = []\n",
    "encdec_atts = []\n",
    "\n",
    "for i in range(hparams.num_hidden_layers):\n",
    "    enc_att = tf.get_default_graph().get_operation_by_name(\n",
    "        \"body/model/parallel_0/body/encoder/layer_%i/self_attention/multihead_attention/dot_product_attention/attention_weights\" % i).values()[0]\n",
    "    dec_att = tf.get_default_graph().get_operation_by_name(\n",
    "        \"body/model/parallel_0/body/decoder/layer_%i/self_attention/multihead_attention/dot_product_attention/attention_weights\" % i).values()[0]\n",
    "    encdec_att = tf.get_default_graph().get_operation_by_name(\n",
    "        \"body/model/parallel_0/body/decoder/layer_%i/encdec_attention/multihead_attention/dot_product_attention/attention_weights\" % i).values()[0]\n",
    "\n",
    "    enc_atts.append(enc_att)\n",
    "    dec_atts.append(dec_att)\n",
    "    encdec_atts.append(encdec_att)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test translation from the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:global_step/sec: 0\n",
      "Input:     For example, during the 2008 general election in Florida, 33% of early voters were African-Americans, who accounted however for only 13% of voters in the State.\n",
      "Gold:      Beispielsweise waren bei den allgemeinen Wahlen 2008 in Florida 33% der Wähler, die im Voraus gewählt haben, Afro-Amerikaner, obwohl sie nur 13% der Wähler des Bundesstaates ausmachen.\n",
      "Gold out:  So waren 33 den allgemeinen Wahlen im in der a 33 % der Frühjungdie nur Land die wurden, die ro- Amerikaner, die sie nur 13 % der Wähler im Staates staats betra.\n",
      "INFO:tensorflow:Recording summary at step 250000.\n"
     ]
    }
   ],
   "source": [
    "inp, out, logits = sess.run([inputs['inputs'], target, est_spec.predictions['predictions']])\n",
    "\n",
    "print(\"Input:    \", decode(inp[0]))\n",
    "print(\"Gold:     \", decode(out[0]))\n",
    "logits = np.squeeze(logits[0])\n",
    "tokens = np.argmax(logits, axis=1)\n",
    "print(\"Gold out: \", decode(tokens))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize Custom Sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "eng = \"I have three dogs.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ich habe drei Hunde.\n"
     ]
    }
   ],
   "source": [
    "inp_ids = encode(eng)\n",
    "beam_decode = sess.run(beam_out.predictions['outputs'], {\n",
    "    inputs['inputs']: np.expand_dims(np.expand_dims(inp_ids, axis=2), axis=3),\n",
    "})\n",
    "trans = decode(beam_decode[0])\n",
    "print(trans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "output_ids = beam_decode\n",
    "\n",
    "# Get attentions\n",
    "np_enc_atts, np_dec_atts, np_encdec_atts = sess.run([enc_atts, dec_atts, encdec_atts], {\n",
    "    inputs['inputs']: np.expand_dims(np.expand_dims(inp_ids, axis=2), axis=3),\n",
    "    target: np.expand_dims(np.expand_dims(output_ids, axis=2), axis=3),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
       "    return false;\n",
       "}"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%javascript\n",
    "IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
    "    return false;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interpreting the Visualizations\n",
    "- The layers drop down allow you to view the different Transformer layers, 0-indexed of course.\n",
    "  - Tip: The first layer, last layer and 2nd to last layer are usually the most interpretable.\n",
    "- The attention dropdown allows you to select different pairs of encoder-decoder attentions:\n",
    "  - All: Shows all types of attentions together. NOTE: There is no relation between heads of the same color - between the decoder self attention and decoder-encoder attention since they do not share parameters.\n",
    "  - Input - Input: Shows only the encoder self-attention.\n",
    "  - Input - Output: Shows the decoder’s attention on the encoder. NOTE: Every decoder layer attends to the final layer of encoder so the visualization will show the attention on the final encoder layer regardless of what layer is selected in the drop down.\n",
    "  - Output - Output: Shows only the decoder self-attention. NOTE: The visualization might be slightly misleading in the first layer since the text shown is the target of the decoder, the input to the decoder at layer 0 is this text with a GO symbol prepreded.\n",
    "- The colored squares represent the different attention heads.\n",
    "  - You can hide or show a given head by clicking on it’s color.\n",
    "  - Double clicking a color will hide all other colors, double clicking on a color when it’s the only head showing will show all the heads again.\n",
    "- You can hover over a word to see the individual attention weights for just that position.\n",
    "  - Hovering over the words on the left will show what that position attended to.\n",
    "  - Hovering over the words on the right will show what positions attended to it.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "inp_text = to_tokens(inp_ids)\n",
    "out_text = to_tokens(output_ids)\n",
    "\n",
    "attention.show(inp_text, out_text, np_enc_atts, np_dec_atts, np_encdec_atts)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
